import torch
import warnings
import numpy as np
import torch.nn as nn
from visdom import Visdom
import torch.optim as optim
import torchxrayvision as xrv
from torchvision import transforms
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
from torch.optim.lr_scheduler import StepLR
import os

from conduct import val
from conduct import train
from dataload import CovidCTDataset

# Per-channel normalisation (three channel statistics) shared by both
# pipelines below.  The constants match the commonly used ImageNet RGB
# mean/std values.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training pipeline: resize, then take a random crop covering 50-100 % of the
# image rescaled to 224, randomly mirror it, and convert to a normalised
# tensor.
train_transformer = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.RandomResizedCrop(size=224, scale=(0.5, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
)

# Evaluation pipeline (used for both validation and test sets): no random
# augmentation, just resize + centre crop to 224 and normalise.
val_transformer = transforms.Compose(
    [
        transforms.Resize(size=224),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        normalize,
    ]
)

def majority_vote(vote_sum, votenum):
    """Turn per-sample sums of 0/1 predictions into majority-vote labels.

    A sample is labelled 1.0 only when strictly more than half of the
    ``votenum`` accumulated predictions were 1; ties resolve to 0.0.
    A new float array is returned; ``vote_sum`` is left untouched.
    """
    return (np.asarray(vote_sum) > votenum / 2).astype(float)


def confusion_counts(pred, target):
    """Return ``(TP, TN, FN, FP)`` as plain ints for binary 0/1 labels."""
    pred = np.asarray(pred)
    target = np.asarray(target)
    tp = int(((pred == 1) & (target == 1)).sum())
    tn = int(((pred == 0) & (target == 0)).sum())
    fn = int(((pred == 0) & (target == 1)).sum())
    fp = int(((pred == 1) & (target == 0)).sum())
    return tp, tn, fn, fp


def safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is 0.

    Avoids NaN metrics (e.g. precision when nothing was predicted positive),
    which would otherwise end up in the plots and the performance log.
    """
    return numerator / denominator if denominator else 0.0


def binary_metrics(tp, tn, fn, fp):
    """Return ``(precision, recall, F1, accuracy)`` from confusion counts.

    Every ratio falls back to 0.0 when its denominator is zero.
    """
    precision = safe_ratio(tp, tp + fp)
    recall = safe_ratio(tp, tp + fn)
    f1 = safe_ratio(2 * recall * precision, recall + precision)
    accuracy = safe_ratio(tp + tn, tp + tn + fp + fn)
    return precision, recall, f1, accuracy


def main():
    """Train the DenseNet classifier and report voted validation metrics.

    Every epoch the model is trained and validated; validation predictions
    and scores are accumulated, and every ``votenum`` epochs a majority vote
    over the accumulated predictions is scored (precision, recall, F1,
    accuracy, AUC), plotted to visdom, appended to
    ``performance/<modelname>.txt`` and the model weights are checkpointed
    under ``backup/``.
    """
    batchsize = 5
    total_epoch = 1000
    votenum = 10

    trainset = CovidCTDataset(root_dir='data',
                              txt_COVID='data/trainCT_COVID.txt',
                              txt_NonCOVID='data/trainCT_NonCOVID.txt',
                              transform=train_transformer)
    valset = CovidCTDataset(root_dir='data',
                            txt_COVID='data/valCT_COVID.txt',
                            txt_NonCOVID='data/valCT_NonCOVID.txt',
                            transform=val_transformer)
    testset = CovidCTDataset(root_dir='data',
                             txt_COVID='data/testCT_COVID.txt',
                             txt_NonCOVID='data/testCT_NonCOVID.txt',
                             transform=val_transformer)

    print(len(trainset))
    print(len(valset))
    print(len(testset))

    train_loader = DataLoader(trainset, batch_size=batchsize, drop_last=False, shuffle=True)
    val_loader = DataLoader(valset, batch_size=batchsize, drop_last=False, shuffle=False)
    # NOTE(review): test_loader is built but never consumed in this script.
    test_loader = DataLoader(testset, batch_size=batchsize, drop_last=False, shuffle=False)

    model = xrv.models.DenseNet(num_classes=2, in_channels=3).cuda()
    modelname = 'DenseNet_medical'
    torch.cuda.empty_cache()

    criteria = nn.CrossEntropyLoss()

    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    # NOTE(review): the scheduler is created but never stepped here; confirm
    # whether `train` adjusts the learning rate itself before adding
    # `scheduler.step()`.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

    viz = Visdom(server='http://localhost/', port=8097)

    # Seed both windows with a single all-zero row (one column per line).
    viz.line([[0., 0., 0., 0., 0.]], [0], win='train_performance', update='replace',
             opts=dict(title='train_performance', legend=['precision', 'recall', 'AUC', 'F1', 'acc']))
    viz.line([[0., 0.]], [0], win='train_Loss', update='replace',
             opts=dict(title='train_Loss', legend=['train_loss', 'val_loss']))

    warnings.filterwarnings('ignore')

    # Both output directories must exist before the first checkpoint / log
    # write; previously 'performance' was never created.
    os.makedirs('backup', exist_ok=True)
    os.makedirs('performance', exist_ok=True)

    # Running sums of per-sample predictions / scores over the current
    # voting window of `votenum` epochs.
    vote_pred = np.zeros(len(valset))
    vote_score = np.zeros(len(valset))

    for epoch in range(1, total_epoch + 1):

        train_loss = train(optimizer, epoch, model, train_loader, modelname, criteria)
        targetlist, scorelist, predlist, val_loss = val(model, val_loader, criteria)
        print('target', targetlist)
        print('score', scorelist)
        print('predict', predlist)
        vote_pred = vote_pred + predlist
        vote_score = vote_score + scorelist

        if epoch % votenum != 0:
            continue

        # Majority vote over the last `votenum` epochs; scores are averaged.
        vote_pred = majority_vote(vote_pred, votenum)
        vote_score = vote_score / votenum

        print('vote_pred', vote_pred)
        print('targetlist', targetlist)

        TP, TN, FN, FP = confusion_counts(vote_pred, targetlist)
        p, r, F1, acc = binary_metrics(TP, TN, FN, FP)

        print('TP=', TP, 'TN=', TN, 'FN=', FN, 'FP=', FP)
        print('TP+FP', TP + FP)
        print('precision', p)
        print('recall', r)
        print('F1', F1)
        print('acc', acc)
        AUC = roc_auc_score(targetlist, vote_score)
        print('AUC', AUC)

        # Visualise training progress.
        train_loss_np = train_loss.cpu().detach().numpy()
        val_loss_np = val_loss.cpu().detach().numpy()
        viz.line([[p, r, AUC, F1, acc]], [epoch], win='train_performance', update='append',
                 opts=dict(title='train_performance', legend=['precision', 'recall', 'AUC', 'F1', 'acc']))
        # One row with one column per line, matching the seed row used when
        # the window was created.
        viz.line([[train_loss_np, val_loss_np]], [epoch], win='train_Loss', update='append',
                 opts=dict(title='train_Loss', legend=['train_loss', 'val_loss']))

        summary = (
            'The epoch is {}, average recall: {:.4f}, average precision: {:.4f},average F1: {:.4f}, '
            'average accuracy: {:.4f}, average AUC: {:.4f}'.format(
                epoch, r, p, F1, acc, AUC))
        print(summary)

        # Rolling "latest" checkpoint, overwritten every voting window.
        torch.save(model.state_dict(), "backup/{}.pt".format(modelname))

        # Start a fresh voting window.
        vote_pred = np.zeros(len(valset))
        vote_score = np.zeros(len(valset))

        with open('performance/{}.txt'.format(modelname), 'a', encoding='utf-8') as f:
            f.write('\n ' + summary)

        # Keep a permanent snapshot every ten voting windows.
        if epoch % (votenum * 10) == 0:
            torch.save(model.state_dict(), "backup/{}_epoch{}.pt".format(modelname, epoch))


if __name__ == '__main__':
    main()
